from engine import Utterance
from .widgets import get_figure_widget, FigureWidget

import matplotlib.pyplot as plt
from PySide6.QtCore import Qt, QStringListModel
from PySide6.QtGui import QImage, QPixmap
from PySide6.QtWidgets import *

from pathlib import Path
from typing import List, Set
import sounddevice as sd
import soundfile as sf
import numpy as np
from time import sleep
import umap
import sys
from warnings import filterwarnings, warn
filterwarnings("ignore")


# Distinct per-speaker colors for the UMAP scatter plot, as RGB in [0, 1].
# NOTE: the original used dtype=np.float, an alias deprecated in NumPy 1.20
# and removed in 1.24 — it raises AttributeError on modern NumPy. The
# builtin `float` (np.float64 here) is the supported spelling.
colormap = np.array([
  [0, 127, 70],
  [255, 0, 0],
  [255, 217, 38],
  [0, 135, 255],
  [165, 0, 165],
  [255, 167, 255],
  [97, 142, 151],
  [0, 255, 255],
  [255, 96, 38],
  [142, 76, 0],
  [33, 0, 127],
  [0, 0, 0],
  [183, 183, 183],
  [76, 255, 0],
], dtype=float) / 255


class GUI(QDialog):
  """Main window of the voice-conversion toolbox (owns its QApplication)."""
  # Minimum number of registered utterances before UMAP projections are drawn.
  min_umap_points = 4
  # Number of lines kept in the on-screen log window.
  max_log_lines = 5
  # Maximum size of the utterance-history combo box.
  max_saved_utterances = 20

  def draw_utterance(self, utterance: Utterance, which):
    """Render the utterance's mel spectrogram on the panel named by *which*
    ("current" or generated); embedding drawing is currently disabled."""
    self.draw_mel(mel=utterance.mel, which=which)
    # self.draw_embed(utterance.spk_emb, Path(utterance.path).stem, which)

  def draw_embed(self, spk_emb, name, which):
    """Draw (or clear, when spk_emb is None) the speaker-embedding panel.

    NOTE(review): this method unpacks ``widget.axis`` into two axes
    (``embed_ax, _``) while draw_mel treats ``widget.axis`` as a single
    axes object, and its only call sites in this file are commented out —
    it looks stale; confirm the widget layout before re-enabling it.
    """
    widget = self.cur_ax_widget if which == "current" else self.gen_ax_widget
    embed_ax, _ = widget.axis
    # Empty suptitle hides the utterance name when there is nothing to show.
    embed_ax.figure.suptitle("" if spk_emb is None else name)

    ## Embedding
    # Clear the plot (the colorbar must be removed explicitly, clear() alone
    # leaves it behind)
    if len(embed_ax.images) > 0:
      embed_ax.images[0].colorbar.remove()
    embed_ax.clear()

    # Draw speaker embedding
    # NOTE(review): only the title is set here — the actual heatmap drawing
    # appears to have been removed from this version.
    if spk_emb is not None:
      embed_ax.set_title("embedding")
    embed_ax.set_aspect("equal", "datalim")
    embed_ax.set_xticks([])
    embed_ax.set_yticks([])
    embed_ax.figure.canvas.draw()
    widget.update()

  def draw_mel(self, mel, which):
    """Draw *mel* on the "current" or generated spectrogram panel.

    Passing mel=None clears the panel. For the generated panel, also
    enables/disables the "Vocode only" button accordingly.
    """
    if which == "current":
      widget = self.cur_ax_widget
    else:
      widget = self.gen_ax_widget
    mel_ax = widget.axis

    # Redraw the spectrogram from scratch
    mel_ax.clear()
    if mel is not None:
      mel_ax.imshow(mel, aspect="auto", origin="lower", interpolation='none')
      mel_ax.set_title("mel spectrogram")

    mel_ax.set_xticks([])
    mel_ax.set_yticks([])
    mel_ax.figure.canvas.draw()
    widget.update()
    # The generated panel gates vocoding: no spectrogram, nothing to vocode.
    if which != "current":
      self.vocode_button.setDisabled(mel is None)

  def draw_umap_projections(self, utterances: Set[Utterance]):
    """Redraw the 2-D UMAP projection of all registered speaker embeddings.

    Shows a placeholder message until at least min_umap_points utterances
    are available. Generated utterances (stem containing "_gen_") are
    plotted with an "x" marker, originals with "o".
    """
    self.umap_ax.clear()

    speakers = np.unique([u.spk_name for u in utterances])
    # Cycle through the palette so that more speakers than palette entries
    # does not raise IndexError (the original indexed colormap[i] directly).
    colors = {spk_name: colormap[i % len(colormap)] for i, spk_name in enumerate(speakers)}
    embeds = [u.spk_emb for u in utterances]

    # Display a message if there aren't enough points
    if len(utterances) < self.min_umap_points:
      self.umap_ax.text(.5, .5, "Add %d more points to\ngenerate the projections" %
                (self.min_umap_points - len(utterances)),
                horizontalalignment='center', fontsize=15)
      self.umap_ax.set_title("")

    # Compute the projections
    else:
      if not self.umap_hot:
        self.log(
          "Drawing UMAP projections for the first time, this will take a few seconds.")
        self.umap_hot = True

      # n_neighbors scaled with sqrt of the sample count; cosine metric suits
      # speaker embeddings.
      reducer = umap.UMAP(int(np.ceil(np.sqrt(len(embeds)))), metric="cosine")
      projections = reducer.fit_transform(embeds)

      speakers_done = set()
      # Iterating `utterances` twice in the same process yields the same
      # order, so projections and utterances stay aligned.
      for projection, utterance in zip(projections, utterances):
        color = colors[utterance.spk_name]
        mark = "x" if "_gen_" in Path(utterance.path).stem else "o"
        # Label each speaker only once so the legend has one entry per speaker
        label = None if utterance.spk_name in speakers_done else utterance.spk_name
        speakers_done.add(utterance.spk_name)
        self.umap_ax.scatter(projection[0], projection[1], c=[color], marker=mark, label=label)
      # self.umap_ax.set_title("UMAP projections")
      self.umap_ax.legend(prop={'size': 10})

    # Draw the plot
    self.umap_ax.set_aspect("equal", "datalim")
    self.umap_ax.set_xticks([])
    self.umap_ax.set_yticks([])
    self.umap_ax.figure.canvas.draw()

  def save_audio_file(self, wav, sample_rate):
    """Ask the user for a destination path and write *wav* there.

    Does nothing if the dialog is cancelled; a missing extension defaults
    to .wav.
    """
    dialog = QFileDialog()
    dialog.setDefaultSuffix(".wav")
    fpath, _ = dialog.getSaveFileName(
      parent=self,
      caption="Select a path to save the audio file",
      filter="Audio Files (*.flac *.wav)"
    )
    if not fpath:
      return
    #Default format is wav
    if Path(fpath).suffix == "":
      fpath += ".wav"
    sf.write(fpath, wav, sample_rate)

  def setup_audio_devices(self, sample_rate):
    """Probe the host's audio devices at *sample_rate* and wire up the UI.

    Sets self.audio_in_device to the first usable input device (or None),
    fills the output-device combo box with every usable output, then applies
    the selection via set_audio_device().
    """
    input_devices = []
    output_devices = []
    for device in sd.query_devices():
      # Check if valid input
      try:
        sd.check_input_settings(device=device["name"], samplerate=sample_rate)
        input_devices.append(device["name"])
      # Catch Exception, not a bare except: a bare clause would also swallow
      # KeyboardInterrupt/SystemExit. An unusable input is simply skipped.
      except Exception:
        pass

      # Check if valid output
      try:
        sd.check_output_settings(device=device["name"], samplerate=sample_rate)
        output_devices.append(device["name"])
      except Exception as e:
        # Log a warning only if the device is not an input
        if not device["name"] in input_devices:
          warn("Unsupported output device %s for the sample rate: %d \nError: %s" % (device["name"], sample_rate, str(e)))

    if len(input_devices) == 0:
      self.log("No audio input device detected. Recording may not work.")
      self.audio_in_device = None
    else:
      # Default to the first valid input device
      self.audio_in_device = input_devices[0]

    if len(output_devices) == 0:
      self.log("No supported output audio devices were found! Audio output may not work.")
      self.audio_out_devices_cb.addItems(["None"])
      self.audio_out_devices_cb.setDisabled(True)
    else:
      self.audio_out_devices_cb.clear()
      self.audio_out_devices_cb.addItems(output_devices)
      self.audio_out_devices_cb.currentTextChanged.connect(self.set_audio_device)

    self.set_audio_device()

  def set_audio_device(self):
    """Point sounddevice's default devices at the current UI selection."""
    selection = self.audio_out_devices_cb.currentText()
    output_device = None if selection == "None" else selection

    # If None, sounddevice queries portaudio
    sd.default.device = (self.audio_in_device, output_device)

  def play(self, wav, sample_rate):
    """Play *wav* asynchronously, stopping any playback already running.

    Playback errors are reported in the log instead of raised, so a missing
    or misconfigured output device does not crash the UI.
    """
    try:
      # Stop any ongoing playback before starting a new one
      sd.stop()
      sd.play(wav, sample_rate)
    except Exception as e:
      print(e)
      self.log("Error in audio playback. Try selecting a different audio output device.")
      self.log("Your device must be connected before you start the toolbox.")

  def stop(self):
    """Stop any ongoing audio playback or recording."""
    sd.stop()

  def record_one(self, sample_rate, duration):
    """Record *duration* seconds of mono audio and return it as a 1-D array.

    Blocks until the recording finishes, keeping the UI responsive by
    stepping the loading bar. Returns None if the recording could not start.
    """
    self.record_button.setText("Recording...")
    self.record_button.setDisabled(True)

    self.log("Recording %d seconds of audio" % duration)
    sd.stop()
    try:
      # sd.rec requires an integer frame count; rounding also makes
      # fractional durations work instead of raising.
      wav = sd.rec(int(np.round(duration * sample_rate)), sample_rate, 1)
    except Exception as e:
      print(e)
      self.log("Could not record anything. Is your recording device enabled?")
      self.log("Your device must be connected before you start the toolbox.")
      return None

    # sd.rec is asynchronous: animate the loading bar while it runs.
    for i in np.arange(0, duration, 0.1):
      self.set_loading(i, duration)
      sleep(0.1)
    self.set_loading(duration, duration)
    # Block until the recording buffer is fully filled.
    sd.wait()

    self.log("Done recording.")
    self.record_button.setText("Record")
    self.record_button.setDisabled(False)

    # Drop the trailing channel dimension: (frames, 1) -> (frames,)
    return wav.squeeze()

  @property
  def current_dataset_name(self):
    """Name of the dataset currently selected in the browser."""
    return self.dataset_box.currentText()

  @property
  def current_src_spk(self):
    """Name of the currently selected source speaker."""
    return self.src_spk_box.currentText()

  @property
  def current_tgt_spk(self):
    """Name of the currently selected target speaker."""
    return self.tgt_spk_box.currentText()

  @property
  def current_utterance_name(self):
    """Name of the currently selected utterance file."""
    return self.utterance_box.currentText()

  def browse_file(self):
    """Open a file picker for an audio file.

    Returns the chosen path as a Path, or "" if the dialog was cancelled
    (callers rely on the empty-string sentinel).
    """
    selection, _selected_filter = QFileDialog().getOpenFileName(
      parent=self,
      caption="Select an audio file",
      filter="Audio Files (*.mp3 *.flac *.wav *.m4a)"
    )
    return Path(selection) if selection != "" else ""

  @staticmethod
  def repopulate_box(box, items, random=False):
    """
    Resets a box and adds a list of items. Pass a list of (item, data) pairs instead to join
    data to the items
    """
    # Suppress change signals while the contents are rebuilt
    box.blockSignals(True)
    box.clear()
    for entry in items:
      if isinstance(entry, tuple):
        entry = list(entry)
      else:
        entry = [entry]
      # First element is the display text, the rest is attached as item data
      box.addItem(str(entry[0]), *entry[1:])
    if len(items) > 0:
      start_index = np.random.randint(len(items)) if random else 0
      box.setCurrentIndex(start_index)
    box.setDisabled(len(items) == 0)
    box.blockSignals(False)

  def populate_browser(self, datasets_root: Path, recognized_datasets: List, level: int, random=True):
    """(Re)populate the dataset / speaker / utterance combo boxes.

    level controls how deep the refresh goes: <= 0 refreshes datasets (and
    everything below), <= 1 speakers, <= 2 utterances. With random=True a
    random entry is selected in each refreshed box. If no recognized dataset
    is found, the whole browser is disabled and the method returns early.
    """
    # Select a random dataset
    if level <= 0:
      if datasets_root is not None:
        datasets = [datasets_root.joinpath(d) for d in recognized_datasets]
        datasets = [d.relative_to(datasets_root) for d in datasets if d.exists()]
        self.browser_load_button.setDisabled(len(datasets) == 0)
      # Short-circuit keeps `datasets` from being read when it is unbound
      # (datasets_root is None).
      if datasets_root is None or len(datasets) == 0:
        msg = "Warning: you d" + ("id not pass a root directory for datasets as argument" \
          if datasets_root is None else "o not have any of the recognized datasets" \
                          " in %s" % datasets_root)
        self.log(msg)
        msg += ".\nThe recognized datasets are:\n\t%s\nFeel free to add your own. You " \
             "can still use the toolbox by recording samples yourself." % \
             ("\n\t".join(map(str, recognized_datasets)))
        print(msg, file=sys.stderr)

        # No datasets: disable the entire browsing UI and bail out.
        self.random_utterance_button.setDisabled(True)
        self.random_speaker_button.setDisabled(True)
        self.random_dataset_button.setDisabled(True)
        self.utterance_box.setDisabled(True)
        self.src_spk_box.setDisabled(True)
        self.tgt_spk_box.setDisabled(True)
        self.dataset_box.setDisabled(True)
        self.browser_load_button.setDisabled(True)
        self.auto_next_checkbox.setDisabled(True)
        return
      self.repopulate_box(self.dataset_box, datasets, random)

    # Select a random src and tgt speakers
    if level <= 1:
      speakers_root = datasets_root.joinpath(self.current_dataset_name)
      # Each speaker is a subdirectory of the dataset
      speaker_names = [d.stem for d in speakers_root.glob("*") if d.is_dir()]
      self.repopulate_box(self.src_spk_box, speaker_names, random)
      self.repopulate_box(self.tgt_spk_box, speaker_names, random)

    # Select a random utterance
    if level <= 2:
      utterances_root = datasets_root.joinpath(
        self.current_dataset_name,
        self.current_src_spk
      )
      utterances = []
      # NOTE(review): browse_file also accepts *.m4a, which is not globbed
      # here — confirm whether that is intentional.
      for extension in ['mp3', 'flac', 'wav']:
        utterances.extend(Path(utterances_root).glob("**/*.%s" % extension))
      utterances = [fpath.relative_to(utterances_root) for fpath in utterances]
      self.repopulate_box(self.utterance_box, utterances, random)

  def browser_select_next(self):
    """Advance the utterance box to the next entry, wrapping around.

    No-op when the box is empty (the original raised ZeroDivisionError on
    the modulo in that case).
    """
    count = self.utterance_box.count()
    if count == 0:
      return
    index = (self.utterance_box.currentIndex() + 1) % count
    self.utterance_box.setCurrentIndex(index)

  @property
  def selected_utterance(self):
    """The Utterance attached (as item data by register_utterance) to the
    currently selected history entry."""
    return self.utterance_history.itemData(self.utterance_history.currentIndex())

  def register_utterance(self, utterance: Utterance):
    """Push *utterance* onto the history box (most recent first) and enable
    the playback/generation buttons.

    The history is capped at max_saved_utterances entries; the oldest one
    is evicted.
    """
    history = self.utterance_history
    history.blockSignals(True)
    history.insertItem(0, Path(utterance.path).stem, utterance)
    history.setCurrentIndex(0)
    history.blockSignals(False)

    if history.count() > self.max_saved_utterances:
      history.removeItem(self.max_saved_utterances)

    for button in (self.play_button, self.generate_button, self.synthesize_button):
      button.setDisabled(False)

  def log(self, line, mode="newline"):
    """Show *line* in the log window.

    mode "newline" appends a new line (oldest dropped past max_log_lines);
    "append" extends the last line; "overwrite" replaces it. When the log
    is still empty, "append"/"overwrite" fall back to starting a new line
    (the original raised IndexError on self.logs[-1]).
    """
    if mode == "newline" or not self.logs:
      self.logs.append(line)
      if len(self.logs) > self.max_log_lines:
        del self.logs[0]
    elif mode == "append":
      self.logs[-1] += line
    elif mode == "overwrite":
      self.logs[-1] = line
    log_text = '\n'.join(self.logs)

    self.log_window.setText(log_text)
    # Let Qt repaint immediately, since logging often happens mid-task
    self.app.processEvents()

  def set_loading(self, value, maximum=1):
    """Show progress value/maximum on the loading bar; text hidden at 0.

    Values are scaled by 100 and cast to int: Qt's QProgressBar only
    accepts integers, while callers pass fractional seconds (record_one
    iterates np.arange floats — the original float setValue breaks under
    PySide6).
    """
    # Update the maximum first so the new value is not clamped against the
    # previous, possibly smaller, range.
    self.loading_bar.setMaximum(int(round(maximum * 100)))
    self.loading_bar.setValue(int(round(value * 100)))
    self.loading_bar.setTextVisible(value != 0)
    self.app.processEvents()

  def populate_gen_options(self, seed, trim_silences):
    """Reflect the generation options in the UI: a non-None *seed* checks
    the seed checkbox and enables its textbox; None resets it to 0 and
    disables editing. (*trim_silences* is currently unused here.)"""
    use_seed = seed is not None
    self.random_seed_checkbox.setChecked(use_seed)
    self.seed_textbox.setText(str(seed if use_seed else 0))
    self.seed_textbox.setEnabled(use_seed)

  def update_seed_textbox(self):
    """Enable the seed textbox exactly when the random-seed box is checked."""
    self.seed_textbox.setEnabled(self.random_seed_checkbox.isChecked())

  def reset_interface(self):
    """Return the UI to its initial state: clear both spectrogram panels,
    reset the progress bar, disable every action button that needs a loaded
    utterance or generated output, and blank the log window."""
    # self.draw_embed(None, None, "current")
    # self.draw_embed(None, None, "generated")
    self.draw_mel(None, "current")
    self.draw_mel(None, "generated")
    # self.draw_umap_projections(set())
    self.set_loading(0)
    self.play_button.setDisabled(True)
    self.generate_button.setDisabled(True)
    self.synthesize_button.setDisabled(True)
    self.vocode_button.setDisabled(True)
    self.replay_wav_button.setDisabled(True)
    self.export_wav_button.setDisabled(True)
    # Plain loop instead of the original list comprehension, which was
    # built solely for its side effects.
    for _ in range(self.max_log_lines):
      self.log("")

  def __init__(self):
    """Build the entire Qt interface and show the window.

    Note: this creates the QApplication itself (kept as self.app, used by
    log/set_loading/start) before calling the QDialog initializer, so this
    class owns the event loop.
    """
    ## Initialize the application
    self.app = QApplication(sys.argv)
    super().__init__(None)
    self.setWindowTitle("Voice Conversion app")


    ## Main layouts
    # Root
    root_layout = QGridLayout()
    self.setLayout(root_layout)

    # Browser
    browser_layout = QGridLayout()
    root_layout.addLayout(browser_layout, 0, 0, 1, 2)

    # Generation
    gen_layout = QVBoxLayout()
    root_layout.addLayout(gen_layout, 1, 0, 1, 2)

    # Projections
    self.projections_layout = QVBoxLayout()
    root_layout.addLayout(self.projections_layout, 1, 2, 1, 1)

    # Visualizations
    vis_layout = QVBoxLayout()
    root_layout.addLayout(vis_layout, 0, 2, 1, 1)


    ## Projections
    # UMap scatter plot of speaker embeddings, redrawn by draw_umap_projections
    fig, self.umap_ax = plt.subplots(figsize=(3, 3), facecolor="#F0F0F0")
    fig.subplots_adjust(left=0.02, bottom=0.02, right=0.98, top=0.98)
    self.projections_layout.addWidget(get_figure_widget(fig))
    # Set True after the first (slow) UMAP fit; checked in draw_umap_projections
    self.umap_hot = False
    self.clear_button = QPushButton("Clear")
    self.projections_layout.addWidget(self.clear_button)


    ## Browser
    # Dataset, speaker and utterance selection
    # `i` tracks the current row of the browser grid as widgets are added
    i = 0
    self.dataset_box = QComboBox()
    browser_layout.addWidget(QLabel("<b>Dataset</b>"), i, 0)
    browser_layout.addWidget(self.dataset_box, i + 1, 0)
    self.src_spk_box = QComboBox()
    browser_layout.addWidget(QLabel("<b>Source speaker</b>"), i, 1)
    browser_layout.addWidget(self.src_spk_box, i + 1, 1)
    self.utterance_box = QComboBox()
    browser_layout.addWidget(QLabel("<b>Utterance</b>"), i, 2)
    browser_layout.addWidget(self.utterance_box, i + 1, 2)
    self.browser_load_button = QPushButton("Load")
    browser_layout.addWidget(self.browser_load_button, i + 1, 3)
    i += 2

    # Random buttons
    self.random_dataset_button = QPushButton("Random")
    browser_layout.addWidget(self.random_dataset_button, i, 0)
    self.random_speaker_button = QPushButton("Random")
    browser_layout.addWidget(self.random_speaker_button, i, 1)
    self.random_utterance_button = QPushButton("Random")
    browser_layout.addWidget(self.random_utterance_button, i, 2)
    self.auto_next_checkbox = QCheckBox("Auto select next")
    self.auto_next_checkbox.setChecked(True)
    browser_layout.addWidget(self.auto_next_checkbox, i, 3)
    i += 1

    # Utterance box (history of loaded/recorded utterances, most recent first)
    browser_layout.addWidget(QLabel("<b>Use source from:</b>"), i, 0)
    self.utterance_history = QComboBox()
    browser_layout.addWidget(self.utterance_history, i, 1, 1, 3)
    i += 1

    # Random & next utterance buttons
    self.browser_browse_button = QPushButton("Browse")
    browser_layout.addWidget(self.browser_browse_button, i, 0)
    self.record_button = QPushButton("Record")
    browser_layout.addWidget(self.record_button, i, 1)
    self.play_button = QPushButton("Play")
    browser_layout.addWidget(self.play_button, i, 2)
    self.stop_button = QPushButton("Stop")
    browser_layout.addWidget(self.stop_button, i, 3)
    i += 1


    # Model and audio output selection
    self.tgt_spk_box = QComboBox()
    browser_layout.addWidget(QLabel("<b>Target speaker</b>"), i, 0)
    browser_layout.addWidget(self.tgt_spk_box, i + 1, 0)

    # Populated later by setup_audio_devices
    self.audio_out_devices_cb=QComboBox()
    browser_layout.addWidget(QLabel("<b>Audio Output</b>"), i, 1)
    browser_layout.addWidget(self.audio_out_devices_cb, i + 1, 1)
    i += 2

    #Replay & Save Audio
    browser_layout.addWidget(QLabel("<b>Toolbox Output:</b>"), i, 0)
    self.wavs_cb = QComboBox()
    self.wavs_cb_model = QStringListModel()
    self.wavs_cb.setModel(self.wavs_cb_model)
    self.wavs_cb.setToolTip("Select one of the last generated wavs in this section for replaying or exporting")
    browser_layout.addWidget(self.wavs_cb, i, 1)
    self.replay_wav_button = QPushButton("Replay")
    self.replay_wav_button.setToolTip("Replay last generated vocoder")
    browser_layout.addWidget(self.replay_wav_button, i, 2)
    self.export_wav_button = QPushButton("Export")
    self.export_wav_button.setToolTip("Save last generated vocoder audio in filesystem as a wav file")
    browser_layout.addWidget(self.export_wav_button, i, 3)
    i += 1


    ## Embed & spectrograms
    vis_layout.addStretch()

    # One single-axes figure per panel ("current" source and generated);
    # draw_mel draws into these via the FigureWidget wrappers.
    gridspec_kw = {"width_ratios": [1]}
    fig, cur_ax = plt.subplots(
      1, 1, figsize=(5, 2), gridspec_kw=gridspec_kw
    )
    fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8)
    self.cur_ax_widget = FigureWidget(fig, cur_ax)
    vis_layout.addWidget(self.cur_ax_widget)

    fig, gen_ax = plt.subplots(
      1, 1, figsize=(5, 2), gridspec_kw=gridspec_kw
    )
    fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8)
    self.gen_ax_widget = FigureWidget(fig, gen_ax)
    vis_layout.addWidget(self.gen_ax_widget)

    # Match the axes background to the window and hide the frame
    # for ax in self.cur_ax_widget.axis.tolist() + self.gen_ax_widget.axis.tolist():
    for ax in [self.cur_ax_widget.axis, self.gen_ax_widget.axis]:
      ax.set_facecolor("#F0F0F0")
      for side in ["top", "right", "bottom", "left"]:
        ax.spines[side].set_visible(False)


    ## Generation
    layout = QHBoxLayout()
    self.generate_button = QPushButton("Synthesize and vocode")
    layout.addWidget(self.generate_button)
    self.synthesize_button = QPushButton("Synthesize only")
    layout.addWidget(self.synthesize_button)
    self.vocode_button = QPushButton("Vocode only")
    layout.addWidget(self.vocode_button)
    gen_layout.addLayout(layout)

    layout_seed = QGridLayout()
    self.random_seed_checkbox = QCheckBox("Random seed:")
    self.random_seed_checkbox.setToolTip("When checked, makes the synthesizer and vocoder deterministic.")
    layout_seed.addWidget(self.random_seed_checkbox, 0, 0)
    self.seed_textbox = QLineEdit()
    self.seed_textbox.setMaximumWidth(80)
    layout_seed.addWidget(self.seed_textbox, 0, 1)
    gen_layout.addLayout(layout_seed)

    self.loading_bar = QProgressBar()
    gen_layout.addWidget(self.loading_bar)

    # Log window backed by the self.logs list (see log())
    self.log_window = QLabel()
    self.log_window.setAlignment(Qt.AlignBottom | Qt.AlignLeft)
    gen_layout.addWidget(self.log_window)
    self.logs = []
    gen_layout.addStretch()


    ## Set the size of the window and of the elements
    max_size = self.screen().availableGeometry().size() * 0.7
    self.resize(max_size)

    ## Finalize the display
    self.reset_interface()
    self.show()

  def start(self):
    """Enter the Qt event loop; returns only when the window is closed."""
    # exec_() is the deprecated PySide2-era alias kept for Python-2
    # keyword compatibility; PySide6's canonical spelling is exec().
    self.app.exec()